# pylint:disable=too-many-boolean-expressions
from __future__ import annotations
from typing import Any

import capstone

from angr.knowledge_plugins.functions import Function


def is_function_security_check_cookie(func, project, security_cookie_addr: int) -> bool:
    """
    Decide whether ``func`` is ``__security_check_cookie`` for a known cookie address.

    The entry block must contain exactly two instructions: a ``cmp`` of ``rcx`` (AMD64) or ``ecx`` (X86) against the
    memory slot located at ``security_cookie_addr``, followed by a ``jne``. On AMD64 the slot is addressed
    RIP-relatively; on X86 it is an absolute operand with no base or index register.

    :param func:                    The function to examine.
    :param project:                 The angr project that owns ``func``.
    :param security_cookie_addr:    Address of the security cookie global.
    :return:                        True if the entry block matches the pattern, False otherwise.
    """
    if func.is_plt or func.is_syscall or func.is_simprocedure:
        return False

    # lift the entry block and require exactly two decoded instructions
    entry = project.factory.block(func.addr)
    if entry.instructions != 2:
        return False
    insns = entry.capstone.insns
    if not insns or len(insns) != 2:
        return False
    cmp_insn, branch_insn = insns

    arch_name = project.arch.name
    if arch_name == "AMD64":
        cookie_reg = capstone.x86.X86_REG_RCX
    elif arch_name == "X86":
        cookie_reg = capstone.x86.X86_REG_ECX
    else:
        # only the two x86-family layouts are recognized
        return False

    if cmp_insn.mnemonic != "cmp" or len(cmp_insn.operands) != 2:
        return False
    reg_op, mem_op = cmp_insn.operands
    if reg_op.type != capstone.x86.X86_OP_REG or reg_op.reg != cookie_reg:
        return False
    if mem_op.type != capstone.x86.X86_OP_MEM or mem_op.mem.index != 0:
        return False

    if arch_name == "AMD64":
        # a RIP-relative displacement is relative to the end of the instruction
        slot_matches = (
            mem_op.mem.base == capstone.x86.X86_REG_RIP
            and mem_op.mem.disp + cmp_insn.address + cmp_insn.size == security_cookie_addr
        )
    else:
        # absolute operand: base register id 0 (none), displacement is the address itself
        slot_matches = mem_op.mem.base == 0 and mem_op.mem.disp == security_cookie_addr

    return slot_matches and branch_insn.mnemonic == "jne"


def is_function_security_check_cookie_strict(func: Function, project) -> tuple[bool, int | None]:
    """
    Identify ``__security_check_cookie`` when the cookie address is not known in advance.

    The function must have 5 or 6 blocks. When there are 6, one extra block (presumably the BugCheck routine) is
    dropped: the first block if it does not start at ``func.addr``, otherwise the last one. The remaining five must be:

    * an entry block ``cmp rcx, [cookie]; jne ...`` (``ecx`` and an absolute operand on X86), from which the cookie
      address is recovered;
    * three middle blocks whose raw bytes equal a fixed AMD64 sequence;
    * a final block consisting of a single ``jmp``.

    Only AMD64 can currently match, because the X86 byte sequences are not known yet.

    :param func:    The function to examine.
    :param project: The angr project that owns ``func``.
    :return:        ``(True, cookie_addr)`` on a match, ``(False, None)`` otherwise.
    """
    # security_cookie_addr is unavailable; we examine all bytes in this function
    if func.is_plt or func.is_syscall or func.is_simprocedure:
        return False, None
    if len(func.block_addrs_set) not in {5, 6}:
        return False, None
    block_bytes: list[tuple[int, Any, bytes]] = [
        (b.addr, b, b.bytes)
        for b in sorted(func.blocks, key=lambda b: b.addr)
        if isinstance(b.addr, int) and b.bytes is not None
    ]
    if not block_bytes:
        # every block was filtered out above; nothing can match (and block_bytes[0] would raise IndexError)
        return False, None
    if block_bytes[0][0] != func.addr:
        # the first block is probably the BugCheck function - skip it
        block_bytes = block_bytes[1:]
    elif len(block_bytes) == 6:
        # skip the last block, which is probably the BugCheck function
        block_bytes = block_bytes[:-1]
    if len(block_bytes) != 5:
        return False, None

    # check the first block
    # cmp  rcx, [xxx]     (ecx, [abs] on X86)
    # jnz  xxx
    first_block = block_bytes[0][1]
    if len(first_block.capstone.insns) != 2:
        return False, None
    ins0, ins1 = first_block.capstone.insns
    security_cookie_addr = None
    if (
        project.arch.name == "AMD64"
        and ins0.mnemonic == "cmp"
        and len(ins0.operands) == 2
        and ins0.operands[0].type == capstone.x86.X86_OP_REG
        and ins0.operands[0].reg == capstone.x86.X86_REG_RCX
        and ins0.operands[1].type == capstone.x86.X86_OP_MEM
        and ins0.operands[1].mem.base == capstone.x86.X86_REG_RIP
        and ins0.operands[1].mem.index == 0
        and ins1.mnemonic == "jne"
    ):
        # RIP-relative: the displacement is relative to the end of the cmp instruction
        security_cookie_addr = ins0.operands[1].mem.disp + ins0.address + ins0.size
    if (
        project.arch.name == "X86"
        and ins0.mnemonic == "cmp"
        and len(ins0.operands) == 2
        and ins0.operands[0].type == capstone.x86.X86_OP_REG
        and ins0.operands[0].reg == capstone.x86.X86_REG_ECX
        and ins0.operands[1].type == capstone.x86.X86_OP_MEM
        and ins0.operands[1].mem.base == 0
        and ins0.operands[1].mem.index == 0
        and ins1.mnemonic == "jne"
    ):
        # absolute operand: the displacement is the cookie address itself
        security_cookie_addr = ins0.operands[1].mem.disp

    if security_cookie_addr is None:
        return False, None

    # the last block should be a single jmp
    last_block = block_bytes[-1][1]
    if len(last_block.capstone.insns) != 1:
        return False, None
    if last_block.capstone.insns[-1].mnemonic != "jmp":
        return False, None

    if project.arch.name != "AMD64":
        # TODO: x86 bytes - until they are known, no other architecture can match
        return False, None
    # rol rcx, 0x10; test cx, 0xffff; jne +1  |  ret  |  ror rcx, 0x10
    expected_bytes = [b"\x48\xc1\xc1\x10\x66\xf7\xc1\xff\xff\x75\x01", b"\xc3", b"\x48\xc1\xc9\x10"]

    # compare the three middle blocks; each is truncated at the start of the following block so that
    # overlapping blocks are compared only on their own bytes
    existing_bytes = [
        data[: next_addr - addr]
        for (addr, _, data), (next_addr, _, _) in zip(block_bytes[1:-1], block_bytes[2:])
    ]
    if existing_bytes == expected_bytes:
        return True, security_cookie_addr
    return False, None


def _init_cookie_store_regions(func: Function) -> list[tuple[int, int]] | None:
    """
    Collect the ``(addr, size)`` regions of ``func`` whose last instruction may be the cookie store.

    Two layouts are accepted:

    * one endpoint and one return site: the two predecessors of that return block;
    * two endpoints and two return sites whose blocks end at the same address but start at different
      addresses (overlapping blocks): the non-shared prefix of the lower block.

    :return: The candidate regions, or None when ``func`` has neither layout.
    """
    if len(func.endpoints) == 1 and len(func.ret_sites) == 1:
        # the function is normalized: a single return block reached from exactly two predecessors
        (ret_node,) = func.ret_sites
        regions = [(pred.addr, pred.size) for pred in func.graph.predecessors(ret_node)]
        return regions if len(regions) == 2 else None
    if len(func.endpoints) == 2 and len(func.ret_sites) == 2:
        # the function is not normalized: two return blocks that share a common tail
        lower, upper = sorted(func.endpoints, key=lambda node: node.addr)
        if lower.addr < upper.addr and lower.addr + lower.size == upper.addr + upper.size:
            return [(lower.addr, upper.addr - lower.addr)]
    return None


def _is_cookie_store(insn, arch_name: str, security_cookie_addr: int | None) -> bool:
    """
    Return True if ``insn`` is ``mov [mem], reg`` whose memory destination resolves to ``security_cookie_addr``.

    On AMD64 the destination must be RIP-relative; on X86 it must be absolute (base register id 0, no index).
    Any other architecture never matches.
    """
    if arch_name == "AMD64":
        expected_base = capstone.x86.X86_REG_RIP
    elif arch_name == "X86":
        expected_base = 0
    else:
        return False
    if insn.mnemonic != "mov" or len(insn.operands) != 2:
        return False
    dst, src = insn.operands
    if (
        dst.type != capstone.x86.X86_OP_MEM
        or dst.mem.base != expected_base
        or dst.mem.index != 0
        or src.type != capstone.x86.X86_OP_REG
    ):
        return False
    # RIP-relative displacements are measured from the end of the instruction
    target = dst.mem.disp + insn.address + insn.size if arch_name == "AMD64" else dst.mem.disp
    return target == security_cookie_addr


def is_function_security_init_cookie(func: Function, project, security_cookie_addr: int | None) -> bool:
    """
    Decide whether ``func`` is ``__security_init_cookie`` by locating the store that writes the cookie.

    ``func`` must have one of the return layouts accepted by ``_init_cookie_store_regions``, and at least one
    candidate region must end with ``mov [security_cookie_addr], reg``.

    :param func:                    The function to examine.
    :param project:                 The angr project that owns ``func``.
    :param security_cookie_addr:    Address of the security cookie global. When None, no store can match.
    :return:                        True if the cookie store is found, False otherwise.
    """
    if func.is_plt or func.is_syscall or func.is_simprocedure:
        return False
    regions = _init_cookie_store_regions(func)
    if regions is None:
        return False
    for region_addr, region_size in regions:
        # lift only the candidate region and inspect its final instruction
        lifted = project.factory.block(region_addr, size=region_size)
        if not lifted.instructions:
            continue
        insns = lifted.capstone.insns
        if insns and _is_cookie_store(insns[-1], project.arch.name, security_cookie_addr):
            return True
    return False


def is_function_security_init_cookie_win8(func: Function, project, security_cookie_addr: int) -> bool:
    """
    Match the AMD64 ``__security_init_cookie`` variant handled here (the "win8" layout, per this function's name).

    The entry block must contain exactly three instructions, the first being ``mov rax, [rip+disp]`` that reads
    ``security_cookie_addr`` and the last being ``je``. The entry block may have at most two successors in the
    function's local graph, and one of them must start with ``movabs <op>, 0x2B992DDFA232`` (presumably the MSVC x64
    default cookie value - confirm before relying on it).

    :param func:                    The function to examine.
    :param project:                 The angr project that owns ``func``.
    :param security_cookie_addr:    Address of the security cookie global.
    :return:                        True if the pattern matches, False otherwise.
    """
    if func.is_plt or func.is_syscall or func.is_simprocedure:
        return False
    # The pattern needs RIP-relative addressing, so only AMD64 can match; bail out before lifting anything
    # (the sibling matchers gate on project.arch.name the same way).
    if project.arch.name != "AMD64":
        return False

    block = project.factory.block(func.addr)
    if block.instructions != 3:
        return False
    insns = block.capstone.insns
    if not insns or len(insns) != 3:
        return False
    ins0 = insns[0]
    if not (
        ins0.mnemonic == "mov"
        and len(ins0.operands) == 2
        and ins0.operands[0].type == capstone.x86.X86_OP_REG
        and ins0.operands[0].reg == capstone.x86.X86_REG_RAX
        and ins0.operands[1].type == capstone.x86.X86_OP_MEM
        and ins0.operands[1].mem.base == capstone.x86.X86_REG_RIP
        and ins0.operands[1].mem.index == 0
        and ins0.operands[1].mem.disp + ins0.address + ins0.size == security_cookie_addr
    ):
        return False
    if insns[-1].mnemonic != "je":
        return False

    # graph.successors() raises for a node that is missing from the graph, so validate the entry node first
    entry_node = func.get_node(block.addr)
    graph = func.graph
    if entry_node is None or entry_node not in graph:
        return False
    succs = list(graph.successors(entry_node))
    if len(succs) > 2:
        return False
    for succ in succs:
        succ_block = project.factory.block(succ.addr)
        if not succ_block.instructions:
            continue
        succ_insns = succ_block.capstone.insns
        if not succ_insns:
            # guard against an empty capstone decode, as the sibling matchers do
            continue
        first_insn = succ_insns[0]
        if (
            first_insn.mnemonic == "movabs"
            and len(first_insn.operands) == 2
            and first_insn.operands[1].type == capstone.x86.X86_OP_IMM
            and first_insn.operands[1].imm == 0x2B992DDFA232
        ):
            return True
    return False


def is_function_likely_security_init_cookie(func: Function) -> bool:
    """
    Fuzzy-match ``__security_init_cookie`` by the set of functions it calls.

    Every ``Function`` node in ``func.transition_graph`` counts as a callee. The match succeeds when the callee names
    include all of ``GetSystemTimeAsFileTime``, ``GetCurrentProcessId``, ``GetCurrentThreadId``, ``GetTickCount`` and
    ``QueryPerformanceCounter``. Additional callees do not prevent a match.

    :param func:    The function to examine.
    :return:        True if all of the listed callees are present, False otherwise.
    """
    required_names = {
        "GetSystemTimeAsFileTime",
        "GetCurrentProcessId",
        "GetCurrentThreadId",
        "GetTickCount",
        "QueryPerformanceCounter",
    }
    called_names = set()
    for node in func.transition_graph:
        if isinstance(node, Function):
            called_names.add(node.name)
    return required_names <= called_names
